try:  # Python 3.3+; the `collections.Iterable` alias was removed in Python 3.10
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable

import numpy as np
import pandas as pd

from .orm_db import DataBase
from .orm.factorengine import *


class Factordb(DataBase):
    """Query helper for per-factor tables exposed through ``DataBase``.

    Each factor is looked up as an attribute of this object named after the
    upper-cased factor name (e.g. factor ``'pe'`` -> ``self.PE``). The table
    is expected to have ``sid``, ``trade_dt`` and ``value`` columns.

    The list of known factor names is read from ``FACTORS_INFO`` once, when
    the object is constructed.
    """

    def __init__(self, engine):
        super(Factordb, self).__init__(engine)
        self.engine = engine
        # Snapshot of available factor names; used by get_factor() when no
        # explicit factor_list is given.
        self.factor_list = self._get_factor_list()

    def _get_factor_list(self):
        """Return the values of ``FACTORS_INFO.NAME`` as a plain list."""
        total_factor = self.query(self.FACTORS_INFO.NAME).to_df()
        return total_factor['name'].tolist()

    def get_factor(self, sid_list=None, factor_list=None,
                   begin_dt='20190101', end_dt='20190501'):
        """Load factor values and combine them into one wide DataFrame.

        Parameters
        ----------
        sid_list : str, iterable, or None
            A single sid, any iterable of sids, or ``None`` for no sid filter.
            Iterators/generators are materialised once, so every factor is
            filtered by the same set of sids.
        factor_list : str, iterable, or None
            Factor name(s) to load. ``None`` means every name in
            ``self.factor_list``. A single string is treated as one factor
            name, not as a sequence of characters.
        begin_dt, end_dt : str
            Inclusive bounds applied to ``trade_dt`` in the query. They are
            compared as given, so they must match the stored format (the
            defaults use ``YYYYMMDD``).

        Returns
        -------
        pandas.DataFrame
            Columns ``sid`` and ``trade_dt``, followed by one column per
            factor. Factors are outer-joined on ``(sid, trade_dt)``
            (``pd.concat`` default), so values missing for a factor are NaN.
            An empty factor list yields an empty frame with only the ``sid``
            and ``trade_dt`` columns.
        """
        if factor_list is None:
            factor_list = self.factor_list
        elif isinstance(factor_list, str):
            # Without this, 'pe' would be iterated as 'p', 'e'.
            factor_list = [factor_list]

        # Materialise non-string iterables once. A generator would otherwise
        # be consumed by the first factor's in_() filter, which silently
        # empties the results of every later factor.
        if not isinstance(sid_list, str) and isinstance(sid_list, Iterable):
            sid_list = list(sid_list)

        df_list = []
        for factor in factor_list:
            table = getattr(self, factor.upper())

            factor_query = self.query(
                table.sid, table.trade_dt, table.value,
            ).filter(
                table.trade_dt >= begin_dt,
                table.trade_dt <= end_dt,
            ).order_by(table.trade_dt, table.sid)

            if isinstance(sid_list, str):
                factor_query = factor_query.filter(table.sid == sid_list)
            elif sid_list is not None:
                # sid_list was converted to a list above.
                factor_query = factor_query.filter(table.sid.in_(sid_list))

            df = factor_query.to_df()
            # Name the value column after the factor, so columns stay
            # distinct after concatenation.
            df.columns = ['sid', 'trade_dt', factor]
            df_list.append(df.set_index(['sid', 'trade_dt']))

        if not df_list:
            # pd.concat raises on an empty list; return an empty frame with
            # the key columns instead.
            return pd.DataFrame(columns=['sid', 'trade_dt'])

        df_summary = pd.concat(df_list, axis=1)
        return df_summary.reset_index()

# Shared module-level instance, created at import time.
# `factor_engine` is presumably supplied by `from .orm.factorengine import *`
# (TODO confirm). Constructing Factordb runs the FACTORS_INFO query in
# __init__, so importing this module presumably needs a working database
# connection.
FactorAPI = Factordb(factor_engine)




























